import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import time
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor


def show_anns(anns):
    """Draw every annotation's mask as a translucent overlay on the current axes.

    Annotations are painted largest-area first so smaller masks end up on top.
    Each mask gets a random RGB colour with alpha 0.35; pixels not covered by
    any mask stay fully transparent. Does nothing for an empty list.

    Args:
        anns: mask dicts, each with an 'area' and a boolean-indexable
            'segmentation' array (as produced by SamAutomaticMaskGenerator).
    """
    if not len(anns):
        return
    ordered = sorted(anns, key=lambda a: a['area'], reverse=True)
    axes = plt.gca()
    axes.set_autoscale_on(False)

    height, width = ordered[0]['segmentation'].shape[:2]
    # RGBA canvas: start white but fully transparent (alpha channel = 0).
    overlay = np.ones((height, width, 4))
    overlay[..., 3] = 0
    for ann in ordered:
        rgba = np.concatenate([np.random.random(3), [0.35]])
        overlay[ann['segmentation']] = rgba
    axes.imshow(overlay)

def loadSamMode(sam_checkpoint="sam_vit_h_4b8939.pth", model_type="vit_h", device="cuda"):
    """Build a SAM model from a checkpoint file and move it to a device.

    The defaults reproduce the previous hard-coded behaviour (ViT-H
    checkpoint on CUDA); pass other values to load a different model
    variant/checkpoint or to run on e.g. "cpu".

    Args:
        sam_checkpoint: path to the checkpoint file handed to the model builder.
        model_type: key into ``sam_model_registry`` selecting the builder
            (e.g. "vit_h").
        device: device passed to ``sam.to(device=...)``.

    Returns:
        The model built by ``sam_model_registry[model_type]``, moved to ``device``.
    """
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    return sam

def maskGenerator(sam):
    """Return a SamAutomaticMaskGenerator for ``sam`` using the library's default settings."""
    return SamAutomaticMaskGenerator(sam)

def maskGenerator2(sam):
    """Return a SamAutomaticMaskGenerator for ``sam`` with tuned, non-default settings.

    Uses a 32x32 point grid, stricter IoU/stability thresholds, one extra
    crop layer with downscaled point sampling, and removal of mask regions
    smaller than 100 pixels.
    """
    settings = dict(
        points_per_side=32,
        pred_iou_thresh=0.86,
        stability_score_thresh=0.92,
        crop_n_layers=1,
        crop_n_points_downscale_factor=2,
        # Small-region post-processing needs OpenCV to be installed.
        min_mask_region_area=100,
    )
    return SamAutomaticMaskGenerator(model=sam, **settings)

def getImgMasks(img, mask_generator, isPth=False):
    """Run ``mask_generator`` on an image and return its masks.

    Args:
        img: the image to segment, or a file path when ``isPth`` is True.
        mask_generator: object exposing ``generate(image)``, e.g. the result
            of ``maskGenerator()``.
        isPth: when True, ``img`` is read from disk with ``cv2.imread`` and
            converted from BGR to RGB before segmentation.

    Returns:
        Whatever ``mask_generator.generate`` returns for the image.

    Raises:
        FileNotFoundError: if ``isPth`` is True and the file cannot be read.
    """
    image = img
    if isPth:
        image = cv2.imread(img)
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising; fail clearly here rather than inside cvtColor.
        if image is None:
            raise FileNotFoundError('could not read image: %r' % (img,))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return mask_generator.generate(image)

def showMasksTxt(masks, index=3):
    """Print a plain-text summary of a list of SAM mask dicts (debug helper).

    Prints the number of masks and the key sets of up to the first two masks.
    It then prints details of ``masks[index]``:
    ``len(seg[0])`` and ``len(seg)`` of its 'segmentation' (presumably
    width and height for an (H, W) mask — confirm), then 'area', 'bbox',
    'predicted_iou', 'point_coords', 'stability_score' and 'crop_box'.

    Args:
        masks: sequence of mask dicts as returned by
            ``SamAutomaticMaskGenerator.generate``.
        index: which mask to print in detail (default 3). Negative values
            count from the end.

    If ``index`` is out of range (including an empty list), a notice is
    printed instead of raising IndexError.
    """
    print(len(masks))
    # Previously masks[0] and masks[1] were read unconditionally.
    for mask in masks[:2]:
        print(mask.keys())
    if not -len(masks) <= index < len(masks):
        print('mask index %d out of range' % index)
        return
    chosen = masks[index]
    seg = chosen['segmentation']
    print(len(seg[0]), len(seg))
    for key in ('area', 'bbox', 'predicted_iou', 'point_coords',
                'stability_score', 'crop_box'):
        print(chosen[key])

def remakeMasks(masks):
    """Return lightweight copies of SAM mask records.

    Each output dict keeps only 'bbox', 'area', 'predicted_iou',
    'point_coords' and 'stability_score', in that order. The
    'segmentation' array and 'crop_box' are intentionally left out.
    """
    kept_keys = ('bbox', 'area', 'predicted_iou', 'point_coords', 'stability_score')
    return [{key: mask[key] for key in kept_keys} for mask in masks]
class SMAObj(object):
    """Holds a SAM model plus an automatic mask generator and runs segmentation.

    Construction loads the model with ``loadSamMode()`` and builds a default
    generator with ``maskGenerator()``. It then segments 'images/dog.jpg'
    once and prints a summary (presumably a smoke test / warm-up).
    """
    def __init__(self, pUse_gpu=True):
        # NOTE(review): stored but currently unused — loadSamMode() as called
        # here always targets "cuda" regardless of this flag.
        self.isGPU = pUse_gpu
        # True while smaImg/smaFile is running; reported by isOcr().
        self.isOcring = False
        self.smaMode = loadSamMode()
        self.mask_generator = maskGenerator(self.smaMode)
        tmpres = self.smaFile('images/dog.jpg')
        showMasksTxt(tmpres)

    def smaImg(self, img):
        """Segment an in-memory image and return the trimmed mask records.

        Args:
            img: image passed directly to ``mask_generator.generate``. No colour
                conversion is done here (smaFile converts BGR->RGB; presumably
                callers must pass RGB too — confirm).

        Returns:
            ``remakeMasks(result)``: one dict per mask without the
            'segmentation' array.
        """
        self.isOcring = True
        try:
            ts = time.perf_counter()  # monotonic clock, suited to durations
            result = self.mask_generator.generate(img)
            print('cast time:%.3f' % (time.perf_counter() - ts))
            return remakeMasks(result)
        finally:
            # Clear the busy flag even when generation raises.
            self.isOcring = False

    def smaFile(self, fpth):
        """Read an image file, convert BGR->RGB, and return the raw generator output.

        Raises:
            FileNotFoundError: if ``cv2.imread`` cannot read ``fpth``.
        """
        self.isOcring = True
        try:
            image = cv2.imread(fpth)
            # cv2.imread returns None (rather than raising) on failure.
            if image is None:
                raise FileNotFoundError('could not read image: %r' % (fpth,))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            return self.mask_generator.generate(image)
        finally:
            self.isOcring = False

    def isOcr(self):
        """Return True while an image is being segmented by smaImg/smaFile."""
        return self.isOcring

def main():
    """Demo entry point: build an SMAObj and print a summary of masks for the sample dog image."""
    demo = SMAObj()
    showMasksTxt(demo.smaFile('images/dog.jpg'))


if __name__ == "__main__":
    main()